
#include <sys/cdefs.h>
#define MPICH_SKIP_MPICXX
#include "esmd_types.h"
#include "cell.h"
#include "swarch.h"
#include "rigid.hpp"
#include "shake.hpp"
#include "lincs.hpp"
#include "cal.h"
extern vec<real> *frep;
extern cal_lock_t *locks;

// Parameter block for the rigid-constraint kernel. The host launcher fills it
// in and the slave-side kernel copies it into a local instance with dma_getn
// before use.
template <typename T>
struct rigid_param {
  cellgrid_t *grid;  // cell grid to process
  real dtfsq;        // the kernel hands its reciprocal (1 / dtfsq) to T::run
  real dtv;          // not read by rigid_force_listed in this file
  T rigid;           // solver object whose run<...>() is invoked per atom (instantiated here with lincs and shake)
};
// Parameter block for the unconstrained position update kernel
// (unconstrained_update_cpe), which computes
// xuc = x + dtv * v + f * dtfsq * rmass for every atom.
struct update_param {
  cellgrid_t *grid;  // cell grid whose local cells are updated
  real dtfsq;        // factor applied to f * rmass
  real dtv;          // factor applied to v
};
// Returns non-zero when cell index (x, y, z) lies outside the half-open box
// [0, nlocal.x) x [0, nlocal.y) x [0, nlocal.z) of `grid`, zero otherwise.
__always_inline int is_ghost(cellgrid_t *grid, int x, int y, int z) {
  // Test one axis at a time; the first axis found out of range decides.
  if (x < 0 || x >= grid->nlocal.x) return 1;
  if (y < 0 || y >= grid->nlocal.y) return 1;
  if (z < 0 || z >= grid->nlocal.z) return 1;
  return 0;
}

#ifdef __sw_host__
#include <qthread.h>
extern void slave_unconstrained_update_cpe(update_param *pm);
// Host-side launcher for the unconstrained position update: packs the
// arguments into an update_param, spawns the slave kernel on it and waits
// for the kernel to finish.
void unconstrained_update_sw(cellgrid_t *grid, real dtv, real dtfsq) {
  // Member order of update_param is {grid, dtfsq, dtv}.
  update_param args = {grid, dtfsq, dtv};

  qthread_spawn(slave_unconstrained_update_cpe, &args);
  qthread_join();
}
// Disabled declaration of the non-listed kernel entry point:
// template <typename T> extern void slave_rigid_force (rigid_param<T> *pm);
// Entry point spawned by rigid_force_listed_sw below.
// NOTE(review): presumably this resolves to rigid_force_listed<T> from the
// __sw_slave__ section through the toolchain's slave_ symbol prefix; confirm.
template <typename T> extern void slave_rigid_force_listed(rigid_param<T> *pm);
// Disabled host launcher for the non-listed kernel, kept for reference:
// template <class T>
// void rigid_force_sw(cellgrid_t *grid, real dtfsq, T &rigid) {
//   rigid_param<T> pm;
//   pm.grid = grid;
//   pm.dtfsq = dtfsq;
//   pm.rigid = rigid;
//   *lbal_cnt = 0;
//   qthread_spawn(slave_rigid_force<T>, &pm);
//   qthread_join();
// }
// template void rigid_force_sw<lincs>(cellgrid_t *grid, real dtfsq, lincs &rigid);
// template void rigid_force_sw<shake> (cellgrid_t *grid, real dtfsq, shake &rigid);
// Host-side launcher for the listed rigid-constraint kernel.
//
// Packs `grid`, `dtfsq` and a copy of `rigid` into a rigid_param<T>, resets
// *lbal_cnt, makes sure the shared `frep` and `locks` buffers exist, then
// spawns slave_rigid_force_listed<T> and waits for it to finish.
//
//   grid   cell grid to process
//   dtfsq  forwarded to the kernel through rigid_param::dtfsq
//   rigid  solver object; copied by value into the parameter block
template<typename T>
void rigid_force_listed_sw(cellgrid_t *grid, real dtfsq, T &rigid) {
  rigid_param<T> pm;
  pm.grid = grid;
  pm.dtfsq = dtfsq;
  // The kernel copies the whole struct, so give the otherwise unused dtv
  // member a defined value instead of shipping indeterminate bytes.
  pm.dtv = 0;
  pm.rigid = rigid;
  // NOTE(review): presumably the counter consumed by the dynamically
  // scheduled cell loop in the kernel; confirm against its definition.
  *lbal_cnt = 0;

  // One slot per atom slot of every cell in the grid (nall.vol() * CELL_CAP).
  // The buffers are created on first use only and are never resized here.
  const auto nslots = grid->nall.vol() * CELL_CAP;
  if (frep == NULL) frep = esmd::allocate<vec<real>>(nslots, "listed/frep");
  if (locks == NULL) {
    locks = esmd::allocate<cal_lock_t>(nslots, "listed/locks");
    memset(locks, 0, nslots * sizeof(cal_lock_t));
  }

  qthread_spawn(slave_rigid_force_listed<T>, &pm);
  qthread_join();

}
template void rigid_force_listed_sw<lincs> (cellgrid_t *grid, real dtfsq, lincs &rigid);
template void rigid_force_listed_sw<shake>(cellgrid_t *grid, real dtfsq, shake &rigid);
#endif
#ifdef __sw_slave__
#include <qthread_slave.h>
#include "dma_macros_new.h"
#include "dma_funcs.hpp"

#define LWPF_UNIT U(RIGID)
#define LWPF_KERNELS K(SCATTER) K(GATHER) K(COMPUTE)
#define EVT_PC0 PC0_CNT_INST
#define EVT_PC2 PC2_CNT_GLD
#define EVT_PC3 PC3_CYCLE
#include "lwpf3/lwpf.h"
#include "memptr.hpp"
// Folds the per-cell force replica held in `f` back into each cell's own
// force array. For cells inside the local box the replica is added to
// cell->f; for ghost cells the replica replaces cell->f. A qthread_syn()
// is issued before any cell is touched.
template<int NBLKS, int BLKSIZE>
__always_inline void cell_f_cache<NBLKS, BLKSIZE>::sum_rigid(cellgrid_t *grid){
    qthread_syn();
    FOREACH_CELL_CPE_RR(grid, i, j, k, cell) {
      const int slot = get_offset_xyz<true>(grid, i, j, k);
      const int count = cell->natom;
      if (!is_ghost(grid, i, j, k)) {
        // Local cell: accumulate the replica on top of the existing forces.
        memptr_t<vec<real>, CELL_CAP, MP_INOUT> own(cell->f, count);
        memptr_t<vec<real>, CELL_CAP, MP_IN> replica(f + slot * CELL_CAP, count);
        for (int a = 0; a < count; a++) {
          own[a] += replica[a];
        }
      } else {
        // Ghost cell: the replica becomes the cell's force array as-is.
        memptr_t<vec<real>, CELL_CAP, MP_IN> replica(f + slot * CELL_CAP, count);
        dma_putn(cell->f, &replica[0], count);
      }
    }
  }
// Slave-side kernel that runs the rigid-group solver T (lincs or shake) over
// the local cells of a grid.
//
// gl_pm points at the parameter block prepared by the host; it is copied into
// a local rigid_param<T> before use, and the grid descriptor it references is
// copied into `lgrid`.
//
// For every local cell the kernel
//   1. loads the cell's own x, shake_xuc, rigid records, f and rmass;
//   2. gathers the guest atoms referenced through guest_id from neighbouring
//      cells via the position cache, shifts them by the difference of the
//      cell bases, stores them at index CELL_CAP + i and zeroes their force;
//   3. calls pm.rigid.run<...>() for every own atom according to its rigid
//      record type (atoms with any other type are skipped);
//   4. writes the own-atom forces back to the cell and adds the guest-atom
//      forces to the force cache.
// After all cells are done the force cache is flushed and sum_rigid() folds
// it into the cells' force arrays.
template <typename T>
void rigid_force_listed(rigid_param<T> *gl_pm) {
  lwpf_enter(RIGID);
  lwpf_start(COMPUTE);
  //lwpf_start(SCATTER);
  rigid_param<T> pm;
  dma_getn(gl_pm, &pm, 1);
  cellgrid_t lgrid;
  dma_getn(pm.grid, &lgrid, 1);
  //scatter_guest_x(&lgrid);
  // Read dtfsq from the local copy fetched above rather than dereferencing
  // the host pointer a second time.
  real rdtfsq = 1. / pm.dtfsq;
  qthread_syn();
  //lwpf_stop(SCATTER);

  cell_x_cache<32, 8, 1> xcache(lgrid.cells);
  cell_f_cache<32, 8> fcache(frep, locks);
  fcache.fill(&lgrid);
  // first_guest_cell is read with nneighbor + 1 entries below.
  int nneighbor = hdcell(lgrid.nn, lgrid.nn, lgrid.nn, lgrid.nn) + 1;

  // NOTE(review): this loop is driven by pm.grid (the original grid pointer)
  // while everything else uses the local copy lgrid; presumably the dynamic
  // scheduling macro needs the shared instance. Confirm.
  FOREACH_LOCAL_CELL_CPE_DYN(pm.grid, cx, cy, cz, icell) {
    cellmeta_t imeta;
    lwpf_start(GATHER);
    dma_getn(&icell->basis, &imeta, 1);
    // Own atoms occupy [0, natom); guest atoms start at CELL_CAP.
    vec<real> x[CELL_CAP + MAX_CELL_GUEST], f[CELL_CAP + MAX_CELL_GUEST], xuc[CELL_CAP + MAX_CELL_GUEST];
    rigid_rec rigid[CELL_CAP];
    real rmass[CELL_CAP + MAX_CELL_GUEST];

    dma_getn(icell->x, x, imeta.natom);
    dma_getn(icell->shake_xuc, xuc, imeta.natom);
    dma_getn(icell->rigid, rigid, imeta.natom);
    dma_getn(icell->f, f, imeta.natom);
    dma_getn(icell->rmass, rmass, imeta.natom);
    dma_getn(icell->rmass + CELL_CAP, rmass + CELL_CAP, imeta.nguest);
    auto first_guest_cell = array_in(icell->first_guest_cell, nneighbor + 1);
    auto guest_id = array_in(icell->guest_id, imeta.nguest);
    FOREACH_NEIGHBOR(&lgrid, cx, cy, cz, dx, dy, dz, jcell){
      int did = hdcell(lgrid.nn, dx, dy, dz);
      // Guests of neighbour `did` are [first_guest_cell[did], first_guest_cell[did + 1]).
      if (first_guest_cell[did] == first_guest_cell[did + 1]) continue;
      int joff = get_offset_xyz<true>(lgrid, cx + dx, cy + dy, cz + dz);
      cellmeta_t jmeta = fetch_ptr((cellmeta_t*)&jcell->basis);
      // Offset between the neighbour's basis and this cell's basis, added to
      // the cached guest coordinates.
      vec<real> dcell = jmeta.basis - imeta.basis;
      for (int i = first_guest_cell[did]; i < first_guest_cell[did + 1]; i ++){
        auto [xj, xucj] = xcache(joff, guest_id[i]);
        x[CELL_CAP + i] = xj + dcell;
        xuc[CELL_CAP + i] = xucj + dcell;
        f[CELL_CAP + i] = 0;
      }
    }
    lwpf_stop(GATHER);

    for (int i = 0; i < imeta.natom; i++) {
      // idx[0] is the atom itself; idx[1..3] come from its rigid record.
      int idx[4];
      idx[0] = i;
      idx[1] = rigid[i].id[0];
      idx[2] = rigid[i].id[1];
      idx[3] = rigid[i].id[2];
      switch (rigid[i].type) {
      case RIGID2:
        pm.rigid.template run<rigid2>(f, rdtfsq, x, xuc, rmass, rigid[i], idx);
        break;
      case RIGID3:
        pm.rigid.template run<rigid3>(f, rdtfsq, x, xuc, rmass, rigid[i], idx);
        break;
      case RIGID3ANGLE:
        pm.rigid.template run<rigid3angle>(f, rdtfsq, x, xuc, rmass, rigid[i], idx);
        break;
      case RIGID4:
        pm.rigid.template run<rigid4>(f, rdtfsq, x, xuc, rmass, rigid[i], idx);
        break;
      default:
        // Any other record type: nothing to do for this atom.
        continue;
      }
    }
    dma_putn(icell->f, f, imeta.natom);
    lwpf_start(SCATTER);
    FOREACH_NEIGHBOR(&lgrid, cx, cy, cz, dx, dy, dz, jcell){
      int did = hdcell(lgrid.nn, dx, dy, dz);
      if (first_guest_cell[did] == first_guest_cell[did + 1]) continue;
      int joff = get_offset_xyz<true>(lgrid, cx + dx, cy + dy, cz + dz);
      // Only the forces are scattered, so the neighbour's metadata and the
      // basis offset are not needed here.
      for (int i = first_guest_cell[did]; i < first_guest_cell[did + 1]; i ++){
        fcache(joff, guest_id[i]) += f[CELL_CAP + i];
      }
    }
    lwpf_stop(SCATTER);
  }
  qthread_syn();

  fcache.flush();
  fcache.sum_rigid(&lgrid);

  lwpf_stop(COMPUTE);
  lwpf_exit(RIGID);
}
// template void rigid_force<lincs>(rigid_param<lincs>*);
// template void rigid_force<shake> (rigid_param<shake> *);
template void rigid_force_listed<lincs> (rigid_param<lincs> *);
template void rigid_force_listed<shake> (rigid_param<shake> *);

// Slave-side kernel for the unconstrained position update. Copies the
// parameter block and the grid descriptor into local storage, then for each
// local cell loads x, v, f and rmass, evaluates
//   xuc = x + dtv * v + f * dtfsq * rmass
// per component, and stores the result to the cell's shake_xuc array.
void unconstrained_update_cpe(update_param *pm) {

  update_param args;
  dma_getn(pm, &args, 1);
  cellgrid_t lgrid;
  dma_getn(args.grid, &lgrid, 1);
  const real dtfsq = args.dtfsq;
  const real dtv = args.dtv;
  FOREACH_LOCAL_CELL_CPE_RR(&lgrid, ii, jj, kk, cell) {
    const int count = cell->natom;
    vec<real> pos[CELL_CAP], vel[CELL_CAP], frc[CELL_CAP], trial[CELL_CAP];
    real rm[CELL_CAP];
    dma_getn(cell->x, pos, count);
    dma_getn(cell->v, vel, count);
    dma_getn(cell->f, frc, count);
    dma_getn(cell->rmass, rm, count);
    for (int a = 0; a < count; a++) {
      // Expression shape is kept as-is so the rounding order does not change.
      trial[a].x = pos[a].x + dtv * vel[a].x + frc[a].x * dtfsq * rm[a];
      trial[a].y = pos[a].y + dtv * vel[a].y + frc[a].y * dtfsq * rm[a];
      trial[a].z = pos[a].z + dtv * vel[a].z + frc[a].z * dtfsq * rm[a];
    }
    dma_putn(cell->shake_xuc, trial, count);
  }
}
#endif
